# Copyright (C) 2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

import gzip
import os
import os.path as osp

import numpy as np

from datumaro.components.annotation import AnnotationType, Label, LabelCategories
from datumaro.components.converter import Converter
from datumaro.components.errors import MediaTypeError
from datumaro.components.extractor import DatasetItem, Importer, SourceExtractor
from datumaro.components.media import Image
from datumaro.util.meta_file_util import has_meta_file, parse_meta_file


class MnistPath:
    TEST_LABELS_FILE = "t10k-labels-idx1-ubyte.gz"
    TEST_IMAGES_FILE = "t10k-images-idx3-ubyte.gz"
    LABELS_FILE = "-labels-idx1-ubyte.gz"
    IMAGES_FILE = "-images-idx3-ubyte.gz"
    IMAGE_SIZE = 28
    NONE_LABEL = 255


class MnistExtractor(SourceExtractor):
    def __init__(self, path, subset=None):
        if not osp.isfile(path):
            raise FileNotFoundError("Can't read annotation file '%s'" % path)

        if not subset:
            file_name = osp.splitext(osp.basename(path))[0]
            if file_name.startswith("t10k"):
                subset = "test"
            else:
                subset = file_name.split("-", maxsplit=1)[0]

        super().__init__(subset=subset)
        self._dataset_dir = osp.dirname(path)

        self._categories = self._load_categories()

        self._items = list(self._load_items(path).values())

    def _load_categories(self):
        if has_meta_file(self._dataset_dir):
            return {
                AnnotationType.label: LabelCategories.from_iterable(
                    parse_meta_file(self._dataset_dir).keys()
                )
            }

        label_cat = LabelCategories()

        labels_file = osp.join(self._dataset_dir, "labels.txt")
        if osp.isfile(labels_file):
            with open(labels_file, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    label_cat.add(line)
        else:
            for i in range(10):
                label_cat.add(str(i))

        return {AnnotationType.label: label_cat}

    def _load_items(self, path):
        items = {}
        with gzip.open(path, "rb") as lbpath:
            labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8)

        meta = []
        metafile = osp.join(self._dataset_dir, self._subset + "-meta.gz")
        if osp.isfile(metafile):
            with gzip.open(metafile, "rb") as f:
                meta = np.frombuffer(f.read(), dtype="<U32")
            meta = meta.reshape(len(labels), int(len(meta) / len(labels)))

        # support for single-channel image only
        images = None
        images_file = osp.join(
            self._dataset_dir, osp.basename(path).replace("labels-idx1", "images-idx3")
        )
        if osp.isfile(images_file):
            with gzip.open(images_file, "rb") as imgpath:
                images = np.frombuffer(imgpath.read(), dtype=np.uint8, offset=16)
                if len(meta) == 0 or len(meta[0]) < 2:
                    images = images.reshape(len(labels), MnistPath.IMAGE_SIZE, MnistPath.IMAGE_SIZE)

        pix_num = 0
        for i, annotation in enumerate(labels):
            annotations = []
            label = annotation
            if label != MnistPath.NONE_LABEL:
                annotations.append(Label(label))

            image = None
            if images is not None:
                if 0 < len(meta) and 1 < len(meta[i]):
                    h, w = int(meta[i][-2]), int(meta[i][-1])
                    image = images[pix_num : pix_num + h * w].reshape(h, w)
                    pix_num += h * w
                else:
                    image = images[i].reshape(MnistPath.IMAGE_SIZE, MnistPath.IMAGE_SIZE)

            if image is not None:
                image = Image(data=image)

            if 0 < len(meta) and (len(meta[i]) == 1 or len(meta[i]) == 3):
                i = meta[i][0]

            items[i] = DatasetItem(id=i, subset=self._subset, media=image, annotations=annotations)
        return items


class MnistImporter(Importer):
    @classmethod
    def find_sources(cls, path):
        return cls._find_sources_recursive(
            path,
            ".gz",
            "mnist",
            file_filter=lambda p: 1 < len(osp.basename(p).split("-"))
            and osp.basename(p).split("-")[1] == "labels",
        )


class MnistConverter(Converter):
    DEFAULT_IMAGE_EXT = ".png"

    def apply(self):
        if self._extractor.media_type() and not issubclass(self._extractor.media_type(), Image):
            raise MediaTypeError("Media type is not an image")

        os.makedirs(self._save_dir, exist_ok=True)
        if self._save_dataset_meta:
            self._save_meta_file(self._save_dir)

        for subset_name, subset in self._extractor.subsets().items():
            labels = []
            images = np.array([])
            item_ids = {}
            image_sizes = {}
            for item in subset:
                anns = [a.label for a in item.annotations if a.type == AnnotationType.label]
                label = 255
                if anns:
                    label = anns[0]
                labels.append(label)

                if item.id != str(len(labels) - 1):
                    item_ids[len(labels) - 1] = item.id

                if item.media and self._save_media:
                    image = item.media
                    if not image.has_data:
                        image_sizes[len(images) - 1] = [0, 0]
                    else:
                        image = image.data
                        if (
                            image.shape[0] != MnistPath.IMAGE_SIZE
                            or image.shape[1] != MnistPath.IMAGE_SIZE
                        ):
                            image_sizes[len(labels) - 1] = [image.shape[0], image.shape[1]]
                        images = np.append(images, image.reshape(-1).astype(np.uint8))

            if subset_name == "test":
                labels_file = osp.join(self._save_dir, MnistPath.TEST_LABELS_FILE)
            else:
                labels_file = osp.join(self._save_dir, subset_name + MnistPath.LABELS_FILE)
            self.save_annotations(labels_file, labels)

            if 0 < len(images):
                if subset_name == "test":
                    images_file = osp.join(self._save_dir, MnistPath.TEST_IMAGES_FILE)
                else:
                    images_file = osp.join(self._save_dir, subset_name + MnistPath.IMAGES_FILE)
                self.save_images(images_file, images)

            # it is't in the original format,
            # this is for storng other names and sizes of images
            if len(item_ids) or len(image_sizes):
                meta = []
                if len(item_ids) and len(image_sizes):
                    # other names and sizes of images
                    size = [MnistPath.IMAGE_SIZE, MnistPath.IMAGE_SIZE]
                    for i in range(len(labels)):
                        w, h = image_sizes.get(i, size)
                        meta.append([item_ids.get(i, i), w, h])

                elif len(item_ids):
                    # other names of images
                    for i in range(len(labels)):
                        meta.append([item_ids.get(i, i)])

                elif len(image_sizes):
                    # other sizes of images
                    size = [MnistPath.IMAGE_SIZE, MnistPath.IMAGE_SIZE]
                    for i in range(len(labels)):
                        meta.append(image_sizes.get(i, size))

                metafile = osp.join(self._save_dir, subset_name + "-meta.gz")
                with gzip.open(metafile, "wb") as f:
                    f.write(np.array(meta, dtype="<U32").tobytes())

        self.save_labels()

    def save_annotations(self, path, data):
        with gzip.open(path, "wb") as f:
            # magic number = 0x0801 (2049, hexadecimal representation)
            # this is used to verify the file with MNIST mark data
            f.write(np.array([0x0801, len(data)], dtype=">i4").tobytes())
            f.write(np.array(data, dtype="uint8").tobytes())

    def save_images(self, path, data):
        with gzip.open(path, "wb") as f:
            # magic number = 0x0803 (2051, hexadecimal representation),
            # this is used to verify the file with MNIST image data
            f.write(
                np.array(
                    [0x0803, len(data), MnistPath.IMAGE_SIZE, MnistPath.IMAGE_SIZE], dtype=">i4"
                ).tobytes()
            )
            f.write(np.array(data, dtype="uint8").tobytes())

    def save_labels(self):
        labels_file = osp.join(self._save_dir, "labels.txt")
        with open(labels_file, "w", encoding="utf-8") as f:
            f.writelines(
                l.name + "\n"
                for l in self._extractor.categories().get(AnnotationType.label, LabelCategories())
            )
